from ActualCausal.Inference.compute_inference import evaluate_inference
import copy
from ActualCausal.Updater.update_params import compute_params
from ActualCausal.Utils.Logging.inter_logger import InterLogger
from ActualCausal.Utils.weighting import separate_weights, get_weights
from ActualCausal.Utils.run_dataset import get_operation, compute_types
from ActualCausal.Train.train_model import train_model
from ActualCausal.Updater.update_params import update_net_parameters
from Model.model_utils import save_model, load_model
from tianshou.data import Batch
import os, time
import numpy as np
from Record.file_management import load_from_pickle, save_to_pickle
from ACState.object_dict import ObjDict
from ActualCausal.Utils.Logging.logging_names import record_names, infer_names

def init_loggers(args, extractor, name="", wdb_run = None, train_logging_names = None, test_logging_names=None):
    """Create the InterLogger instances used by the training loop.

    Args:
        args: nested configuration namespace; reads ``args.inter.train_names``,
            ``args.record.record_graphs``, ``args.record.log_filename``,
            ``args.train.log_interval`` and ``args.train.intermediate_log_interval``.
        extractor: object providing ``get_index(name)``; used to resolve the index
            of the first training object name (``-1`` when there are none).
        name: prefix prepended to every logger's name (e.g. "pre"/"main").
        wdb_run: optional external logging run forwarded to every logger.
        train_logging_names: names recorded by the training and intermediate
            loggers; defaults to ``record_names``.
        test_logging_names: names recorded by the train/test inference loggers;
            defaults to ``infer_names``.

    Returns:
        ``(train_logger, train_inference_logger, test_inference_logger,
        intermediate_logger)``; ``intermediate_logger`` is None when
        ``args.train.intermediate_log_interval <= 0``.
    """
    # Resolve each default independently. The previous one-line conditional
    # bound a tuple to train_logging_names and left test_logging_names as None
    # (operator precedence), and the resolved values were never used.
    if train_logging_names is None:
        train_logging_names = record_names
    if test_logging_names is None:
        test_logging_names = infer_names
    name_idx = extractor.get_index(args.inter.train_names[0]) if len(args.inter.train_names) > 0 else -1

    def _make_logger(suffix, interval, logging_names):
        # All loggers share the same record/file/wandb settings; only the name,
        # interval and recorded names differ.
        return InterLogger(name + suffix, name_idx, args.record.record_graphs, interval, logging_names, args.record.log_filename, denorm=False, wdb_logger=wdb_run)

    train_logger = _make_logger("tr", args.train.log_interval, train_logging_names)
    intermediate_logger = _make_logger("intermediate", args.train.intermediate_log_interval, train_logging_names) if args.train.intermediate_log_interval > 0 else None
    train_inference_logger = _make_logger("tr_infer", args.train.log_interval, test_logging_names)
    test_inference_logger = _make_logger("tst_infer", args.train.log_interval, test_logging_names)
    return train_logger, train_inference_logger, test_inference_logger, intermediate_logger

def pretrain(args, model, train_buffer, test_buffer=None, wrap=None, wdb_run = None):
    """Run a pretraining pass with the ``args.pretrain`` overrides, then score it.

    A deep copy of ``args`` is patched with the pretraining iteration counts,
    log/infer intervals, active-step settings and ``args.inter.pretrain_forms``.
    Before calling ``train_loop(..., pretrain=True)``, ``train_buffer.weight_binary``
    is filled from ``separate_weights`` using ``args.pretrain.weighting_type``.
    ``train_buffer`` is therefore mutated in place.

    Args:
        args: nested configuration namespace (not modified).
        model: model to pretrain.
        train_buffer: training buffer; its ``weight_binary`` is overwritten.
        test_buffer: optional evaluation buffer forwarded to ``train_loop``.
        wrap: forwarded to ``train_loop`` as ``wrap_function``.
        wdb_run: optional logging run forwarded to ``train_loop``.

    Returns:
        ``(passive_like, active_like)``, the reduced, unnormalized likelihoods on
        ``train_buffer``. Either is None when no matching pretrain form is
        configured. When ``args.inter.train_names`` is empty, the "all"
        likelihoods are summed over the last axis.
    """
    pretrain_args = copy.deepcopy(args)
    # train_loop checks args.infer.infer_dataset. Setting only the top-level
    # attribute left the post-training dataset inference enabled. The top-level
    # flag is kept for any other code that reads it.
    pretrain_args.infer_dataset = False
    pretrain_args.infer.infer_dataset = False
    pretrain_args.train.num_iters = args.pretrain.num_iters
    pretrain_args.train.log_interval = args.pretrain.pretrain_log_interval
    pretrain_args.train.intermediate_log_interval = -1
    pretrain_args.infer.infer_interval = args.pretrain.pretrain_infer_interval
    pretrain_args.active.full_steps = args.pretrain.pretrain_full_steps
    pretrain_args.active.active_steps = args.pretrain.pretrain_active_steps
    pretrain_args.active.trace_steps = args.pretrain.pretrain_trace_steps
    pretrain_args.inter.train_forms = args.inter.pretrain_forms
    _, _, binaries = separate_weights(args, args.pretrain.weighting_type, model, train_buffer)
    train_buffer.weight_binary[:len(train_buffer)] = binaries
    train_loop(pretrain_args, model, train_buffer, test_buffer=test_buffer, pretrain=True, wrap_function=wrap, wdb_run=wdb_run)

    forms = args.inter.pretrain_forms
    train_names = args.inter.train_names

    def _likelihood(compute_type):
        # Reduced, unnormalized likelihood of a single compute type over train_buffer.
        return get_operation(model, train_buffer, all_compute=[compute_type], reduced=True, normalized=False, object_names=train_names)

    # Evaluate the passive and active likelihoods after training.
    active_like, passive_like = None, None
    if len(train_names) > 0:
        if "single_passive" in forms:
            passive_like = _likelihood(compute_types.PASSIVE_LIKELIHOOD)
        if "full" in forms or "mask" in forms:
            active_like = _likelihood(compute_types.ACTIVE_OPEN_LIKELIHOOD)
    else:
        if "passive" in forms:
            passive_like = np.sum(_likelihood(compute_types.ALL_PASSIVE_LIKELIHOOD), axis=-1)
        if "all" in forms or "all_mask" in forms:
            active_like = np.sum(_likelihood(compute_types.ALL_ACTIVE_OPEN_LIKELIHOOD), axis=-1)
    return passive_like, active_like

def train_inference(args, model, train_buffer, test_buffer, wdb_run=None):
    """Weight the training buffer with ``args.inter.weighting_type``, then train.

    The binaries returned by ``separate_weights`` are written into
    ``train_buffer.weight_binary``. When ``args.record.save_intermediate`` is
    non-empty, they are also pickled to ``<save_intermediate>/<env>_weights.pkl``.
    The main (non-pretrain) ``train_loop`` is then run.

    Returns:
        Whatever ``train_loop`` returns: ``(train_result, test_result)``.
    """
    # NOTE: a disabled code path used to load these binaries from
    # args.record.load_intermediate instead of recomputing them.
    _values, _weights, binaries = separate_weights(args, args.inter.weighting_type, model, train_buffer)
    save_dir = args.record.save_intermediate
    if len(save_dir):
        weights_path = os.path.join(save_dir, args.environment.env +  "_weights.pkl")
        save_to_pickle(weights_path, binaries)
    train_buffer.weight_binary[:len(train_buffer)] = binaries
    return train_loop(
        args,
        model,
        train_buffer,
        test_buffer,
        pretrain=False,
        wdb_run=wdb_run,
        save_interval=args.record.save_interval,
    )

def train_loop(args, model, train_buffer, test_buffer=None, pretrain=False, wrap_function=None, wdb_run=None, save_interval = 0):
    """Run the training loop with periodic inference evaluation and logging.

    Each of the ``args.train.num_iters`` iterations:
    - trains via ``train_model``;
    - refreshes ``params`` with ``compute_params`` and ``update_net_parameters``;
    - evaluates inference on ``train_buffer``, and on ``test_buffer`` when given;
    - logs the results.

    If ``save_interval > 0``, the model is saved and reloaded every
    ``save_interval`` iterations (iteration 0 excluded).

    Args:
        args: nested configuration namespace (``args.train``, ``args.infer``,
            ``args.record``, ``args.torch``, ...).
        model: model being trained; must expose ``extractor``, and
            ``full_models``/``train_names`` when saving is enabled.
        train_buffer: training data buffer.
        test_buffer: optional evaluation buffer.
        pretrain: forwarded to ``compute_params``; also selects the "pre"/"main"
            logger prefix.
        wrap_function: forwarded to ``train_model`` and ``evaluate_inference``.
        wdb_run: optional logging run forwarded to the loggers.
        save_interval: save/reload period in iterations; 0 disables it.

    Returns:
        ``(train_result, test_result)``, the latest inference results; either
        may be None. When ``args.infer.infer_dataset`` is set, ``train_result``
        is a final evaluation on ``train_buffer`` and ``test_result`` is None.
    """
    device = args.torch.gpu if args.torch.cuda else "cpu"
    log_batch = ["trace", "valid"] # TODO: make this an input parameter
    params = compute_params(0, args, train_buffer, pretrain=pretrain, result=None, model = model)
    # Optional trace-based weighting for evaluation.
    params.eval_trace_weights, params.train_trace_weights = None, None
    # Evaluation trace weights need a test buffer; skip them instead of passing None.
    if args.infer.eval_weight_infer == "trace_weights" and test_buffer is not None:
        _, _, eval_trace_binaries = separate_weights(args, "trace", model, test_buffer)
        params.eval_trace_weights = get_weights(args.infer.eval_weight_lambda, eval_trace_binaries)
    if args.infer.train_weight_infer == "train_trace_weights":
        _, _, train_trace_binaries = separate_weights(args, "trace", model, train_buffer)
        # NOTE(review): the train weights reuse eval_weight_lambda — confirm intended.
        params.train_trace_weights = get_weights(args.infer.eval_weight_lambda, train_trace_binaries)

    train_logger, train_inference_logger, test_inference_logger, intermediate_logger = init_loggers(args, model.extractor, name = "pre" if pretrain else "main", wdb_run=wdb_run)
    # Initialise both results so the return is well defined even when num_iters == 0.
    train_result, test_result = None, None
    for i in range(args.train.num_iters):
        result = train_model(i, args, params, model, train_buffer, log_batch=log_batch, wrap_function=wrap_function, intermediate_logger=intermediate_logger)
        params = compute_params(i, args, train_buffer, pretrain=pretrain, result=result, params=params, model = model)
        update_net_parameters(i, model, args, params)
        # train_weight_infer names a params entry (e.g. "train_trace_weights"); use it if present.
        train_weight_key = args.infer.train_weight_infer
        train_weights = params[train_weight_key] if train_weight_key in params and len(train_weight_key) > 0 else None
        train_result = evaluate_inference(i, args, params, model, train_buffer, wrap_function=wrap_function, weights=train_weights)
        if test_buffer is not None:
            test_result = evaluate_inference(i, args, params, model, test_buffer, test=True, wrap_function=wrap_function, weights=params.eval_trace_weights)
        train_logger.log(i, result)
        if train_result is not None:
            train_inference_logger.log(i, train_result)
        if test_result is not None:
            test_inference_logger.log(i, test_result)
        if save_interval > 0 and i > 0 and i % save_interval == 0:
            save_model(model, args.record.save_dir)
            # Drop the in-memory network before reloading from disk.
            # NOTE(review): `model` is rebound locally; confirm load_model returns
            # the same object the caller holds, otherwise the caller keeps a stale model.
            del model.full_models[model.train_names[0]]
            model = load_model(model, args.record.save_dir, device=device)
    params.infer_num = 0
    if args.infer.infer_dataset:
        # Evaluate at the last iteration index; the loop variable is undefined
        # when num_iters == 0.
        last_iter = max(args.train.num_iters - 1, 0)
        train_result = evaluate_inference(last_iter, args, params, model, train_buffer)
        test_result = None
    return train_result, test_result # additional contains passive and passive_weights

def test_dataset(args, model, buffer, extractor, test=False, weights=None, compute_names=None, wdb_run=None, dataset_log_names=None):
    """Evaluate inference on a whole buffer and log it with a dedicated logger.

    Runs ``evaluate_inference`` at index 0 with ``args.infer.infer_num`` and
    ``args.infer.train_mask_mode``. Each entry of ``compute_names`` is added as a
    non-reduced ``get_operation`` result, and the result is logged at step 0.
    If ``args.infer.eval_weight_infer == "trace_weights"``, a second,
    trace-weighted evaluation is logged at step 1.

    Args:
        args: nested configuration namespace.
        model: model to evaluate.
        buffer: data buffer to evaluate on.
        extractor: object providing ``get_index(name)`` for the logger index.
        test: forwarded to ``evaluate_inference``.
        weights: optional weights for the unweighted evaluation.
        compute_names: compute types to add to the result; defaults to none.
        wdb_run: optional logging run forwarded to the logger.
        dataset_log_names: names to log; defaults to ``infer_names``.

    Returns:
        The (unweighted) result of ``evaluate_inference``, extended with the
        ``compute_names`` entries.
    """
    # Avoid a shared mutable default argument.
    compute_names = [] if compute_names is None else compute_names
    name_idx = extractor.get_index(args.inter.train_names[0]) if len(args.inter.train_names) > 0 else -1
    dataset_log_names = infer_names if dataset_log_names is None else dataset_log_names
    dataset_logger = InterLogger("dataset_logger", name_idx, args.record.record_graphs, 1, dataset_log_names, args.record.log_filename, denorm=False, wdb_logger=wdb_run)

    def _infer_params():
        # Build a fresh params object per call in case evaluate_inference mutates it.
        return ObjDict({"infer_num": args.infer.infer_num, "mask_mode": args.infer.train_mask_mode})

    result = evaluate_inference(0, args, _infer_params(), model, buffer, test=test, weights=weights)
    for cn in compute_names:
        result[cn] = get_operation(model, buffer, all_compute=[cn], reduced=False, keep_all=False)
    dataset_logger.log(0, result)
    if args.infer.eval_weight_infer == "trace_weights":
        _, _, eval_trace_binaries = separate_weights(args, "trace", model, buffer)
        eval_trace_weights = get_weights(args.infer.eval_weight_lambda, eval_trace_binaries)
        weighted_result = evaluate_inference(0, args, _infer_params(), model, buffer, test=test, weights=eval_trace_weights)
        dataset_logger.log(1, weighted_result)

    return result